/*      > C.Set - Set data type */

#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include "set.h"

#ifdef test
#include <stdio.h>
#endif

/* One node of a set's singly linked list.
 *
 * The element is stored inline after the `next` pointer.  Every allocation
 * site in this file requests sizeof(struct link) - 1 + obj_size bytes, so
 * `data` (declared with one byte -- the pre-C99 "struct hack") really spans
 * obj_size bytes.
 * NOTE(review): indexing past data[0] is technically undefined in ISO C; a
 * C99 flexible array member (`char data[];`) is the standard replacement,
 * but all the `sizeof(struct link) - 1` allocations would have to change in
 * the same commit.
 */
struct link
{
        struct link *next;      /* following node, or NULL at the end */
        char data[1];           /* first byte of the stored element */
};

/* Nodes are handled through this pointer typedef throughout the file. */
typedef struct link *link;

/* Return values from functions.  set_copy, set_add, set_remove, set_union,
 * set_intersection and set_difference return OK on success and ERR on
 * failure; the yes/no queries (set_equal, set_member, set_subset, ...)
 * return plain non-zero/zero instead. */

#define OK      1
#define ERR     0

/* Utility function - locate an element in a set.
 *
 * Walks s's list from the front, comparing each stored element with
 * `element` byte-for-byte over s->obj_size bytes.  Returns the first
 * matching node, or NULL if there is none.  On a match, and only then, if
 * `prev` is not NULL, *prev receives the node preceding the match (NULL
 * when the match is the first node) so the caller can unlink it.
 */
static link find (const set s, const void *element, link *prev)
{
        const int nbytes = s->obj_size;
        link before = NULL;
        link node = s->first;

        while ( node != NULL && memcmp(element,node->data,nbytes) != 0 )
        {
                before = node;
                node = node->next;
        }

        if ( node != NULL && prev != NULL )
                *prev = before;

        return node;
}

/* General component routines */

set set_new (int obj_len)
{
        register set s;

        s = malloc(sizeof(struct set));

        if ( s == NULL )
                return NULL;

        s->first    = NULL;
        s->obj_size = obj_len;

        return s;
}

/* Destroy a set: free every element and then the set itself.
 *
 * Passing NULL is a no-op, mirroring free(NULL), so the result of a failed
 * set_new() can be handed here without a check.
 */
void set_free (set s)
{
        if ( s == NULL )
                return;

        set_clear(s);
        free(s);
}

/* Free every element of s, leaving it empty but still usable. */
void set_clear (set s)
{
        link doomed;

        /* Pop nodes off the front until the list is empty */
        while ( (doomed = s->first) != NULL )
        {
                s->first = doomed->next;
                free(doomed);
        }
}

/* Make s1 a copy of s2, replacing whatever s1 held before.
 *
 * Element order is preserved.  Returns OK on success (copying a set onto
 * itself is a no-op that succeeds).  Returns ERR with s1 untouched if the
 * element sizes differ, and ERR with s1 emptied if memory runs out part
 * way through the copy.
 */
int set_copy (set s1, const set s2)
{
        link p;
        link new;
        link last = NULL;       /* most recently appended node of s1 */
        const int size = s2->obj_size;

        if ( s1->obj_size != s2->obj_size )
                return ERR;

        /* Clearing s1 first would also destroy the source */
        if ( s1 == s2 )
                return OK;

        set_clear(s1);

        for ( p = s2->first; p != NULL; p = p->next )
        {
                new = malloc(sizeof(struct link) - 1 + size);
                if ( new == NULL )
                {
                        /* The list is always NULL-terminated, so this is safe */
                        set_clear(s1);
                        return ERR;
                }

                memcpy(new->data,p->data,size);
                new->next = NULL;

                /* Append without assuming anything about struct set's layout */
                if ( last == NULL )
                        s1->first = new;
                else
                        last->next = new;

                last = new;
        }

        return OK;
}

/* Return non-zero if s1 and s2 hold exactly the same elements.
 *
 * Sets with different element sizes are never equal.  A set holds no
 * duplicates (set_add refuses them), so "every element of s1 is in s2"
 * plus "s1 and s2 have the same number of elements" proves that s2 has no
 * element missing from s1.
 */
int set_equal (const set s1, const set s2)
{
        link node = s1->first;
        int count = 0;

        if ( s1->obj_size != s2->obj_size )
                return 0;

        while ( node != NULL )
        {
                if ( find(s2,node->data,NULL) == NULL )
                        return 0;       /* s1 has an element that s2 lacks */

                ++count;
                node = node->next;
        }

        /* Same size => s2 has nothing that s1 lacks */
        return count == set_size(s2);
}

/* Return non-zero if s contains no elements. */
int set_empty (const set s)
{
        return !s->first;
}

/* Return the number of elements in s (walks the whole list). */
int set_size (const set s)
{
        int count = 0;
        link node = s->first;

        while ( node != NULL )
        {
                ++count;
                node = node->next;
        }

        return count;
}

/* Call process() on each element of s in list order, passing a pointer to
 * the stored element.
 *
 * Iteration stops at the first call that returns non-zero.  Returns 1 if
 * every element was processed or a callback stopped with a positive value,
 * and 0 if a callback stopped with a negative value (abnormal termination).
 */
int set_iterate (const set s, int (*process)(void *))
{
        link node;

        for ( node = s->first; node != NULL; node = node->next )
        {
                const int status = process(node->data);

                if ( status < 0 )
                        return 0;       /* error: abnormal termination */
                if ( status > 0 )
                        return 1;       /* callback asked to stop early */
        }

        return 1;
}

/* set-specific routines */

/* Add a copy of the s->obj_size bytes at `object` to s.
 *
 * The new element goes at the front of the list.  Returns OK on success;
 * ERR if an equal element is already present or memory runs out.
 */
int set_add (set s, const void *object)
{
        link node;

        if ( find(s,object,NULL) != NULL )
                return ERR;             /* already a member */

        node = malloc(sizeof(struct link) - 1 + s->obj_size);
        if ( node == NULL )
                return ERR;

        memcpy(node->data,object,s->obj_size);
        node->next = s->first;
        s->first = node;
        return OK;
}

/* Remove the element equal to `object` from s and free its node.
 * Returns OK if it was removed, ERR if it was not a member. */
int set_remove (set s, const void *object)
{
        link before = NULL;
        link victim = find(s,object,&before);

        if ( victim == NULL )
                return ERR;

        /* Unlink: either from its predecessor or from the head of the list */
        if ( before != NULL )
                before->next = victim->next;
        else
                s->first = victim->next;

        free(victim);
        return OK;
}

/* Return 1 if an element equal to `object` is in s, otherwise 0. */
int set_member (const set s, const void *object)
{
        return find(s,object,NULL) ? 1 : 0;
}

/* Make s the union of t and u: every element found in t, in u, or in both.
 *
 * The result is built in a private list and installed only after t and u
 * have been read completely, so s may be the same set as t and/or u.
 * Returns OK on success.  Returns ERR with s unchanged if the element sizes
 * of s, t and u differ or if memory runs out.  In the result, t's elements
 * keep their order and come last; elements found only in u are placed in
 * front of them.
 */
int set_union (set s, const set t, const set u)
{
        link result = NULL;     /* new contents of s */
        link *tail = &result;   /* where the next copy of a t element goes */
        link p;
        link new;

        if ( s->obj_size != t->obj_size || s->obj_size != u->obj_size )
                return ERR;

        /* Copy t, preserving its order */
        for ( p = t->first; p != NULL; p = p->next )
        {
                new = malloc(sizeof(struct link) - 1 + s->obj_size);
                if ( new == NULL )
                        goto out_of_memory;

                memcpy(new->data,p->data,s->obj_size);
                new->next = NULL;
                *tail = new;
                tail = &new->next;
        }

        /* Prepend each element of u that t lacks.  Sets hold no duplicates,
           so checking against t alone is enough. */
        for ( p = u->first; p != NULL; p = p->next )
        {
                if ( find(t,p->data,NULL) != NULL )
                        continue;

                new = malloc(sizeof(struct link) - 1 + s->obj_size);
                if ( new == NULL )
                        goto out_of_memory;

                memcpy(new->data,p->data,s->obj_size);
                new->next = result;
                result = new;
        }

        /* t and u have been fully read, so s can now be replaced safely */
        set_clear(s);
        s->first = result;
        return OK;

out_of_memory:
        /* Discard the partial result; s has not been touched */
        while ( result != NULL )
        {
                new = result->next;
                free(result);
                result = new;
        }
        return ERR;
}

/* Make s the intersection of t and u: the elements found in both.
 *
 * The result is built in a private list and installed only after t and u
 * have been read completely, so s may be the same set as t and/or u.
 * Returns OK on success.  Returns ERR with s unchanged if the element sizes
 * of s, t and u differ or if memory runs out.
 */
int set_intersection (set s, const set t, const set u)
{
        link result = NULL;     /* new contents of s */
        link p;
        link new;

        if ( s->obj_size != t->obj_size || s->obj_size != u->obj_size )
                return ERR;

        for ( p = t->first; p != NULL; p = p->next )
        {
                if ( find(u,p->data,NULL) == NULL )
                        continue;

                new = malloc(sizeof(struct link) - 1 + s->obj_size);
                if ( new == NULL )
                {
                        /* Discard the partial result; s has not been touched */
                        while ( result != NULL )
                        {
                                new = result->next;
                                free(result);
                                result = new;
                        }
                        return ERR;
                }

                memcpy(new->data,p->data,s->obj_size);
                new->next = result;
                result = new;
        }

        /* t and u have been fully read, so s can now be replaced safely */
        set_clear(s);
        s->first = result;
        return OK;
}

/* Make s the difference t - u: the elements of t that are not in u.
 *
 * The result is built in a private list and installed only after t and u
 * have been read completely, so s may be the same set as t and/or u.
 * Returns OK on success.  Returns ERR with s unchanged if the element sizes
 * of s, t and u differ or if memory runs out.
 */
int set_difference (set s, const set t, const set u)
{
        link result = NULL;     /* new contents of s */
        link p;
        link new;

        if ( s->obj_size != t->obj_size || s->obj_size != u->obj_size )
                return ERR;

        for ( p = t->first; p != NULL; p = p->next )
        {
                if ( find(u,p->data,NULL) != NULL )
                        continue;

                new = malloc(sizeof(struct link) - 1 + s->obj_size);
                if ( new == NULL )
                {
                        /* Discard the partial result; s has not been touched */
                        while ( result != NULL )
                        {
                                new = result->next;
                                free(result);
                                result = new;
                        }
                        return ERR;
                }

                memcpy(new->data,p->data,s->obj_size);
                new->next = result;
                result = new;
        }

        /* t and u have been fully read, so s can now be replaced safely */
        set_clear(s);
        s->first = result;
        return OK;
}

/* Return non-zero if every element of s1 is also in s2.
 *
 * Sets with different element sizes are never subsets of one another.  An
 * empty s1 is a subset of any s2 with the same element size.
 */
int set_subset (const set s1, const set s2)
{
        link node = s1->first;

        if ( s1->obj_size != s2->obj_size )
                return 0;

        /* Advance while the current element of s1 is present in s2 */
        while ( node != NULL && find(s2,node->data,NULL) != NULL )
                node = node->next;

        /* Reaching the end means nothing in s1 was missing from s2 */
        return node == NULL;
}

/* Return non-zero if s1 is a proper subset of s2.
 *
 * That is, every element of s1 is in s2 and s2 has at least one element
 * that s1 lacks.  Sets hold no duplicates, so once s1 is known to be a
 * subset, the second condition is the same as s1 being strictly smaller.
 */
int set_proper_subset (const set s1, const set s2)
{
        if ( !set_subset(s1,s2) )
                return 0;

        return set_size(s1) < set_size(s2);
}

/*---------------------------------------------------------------------------*/

#ifdef test
/* set_iterate callback for sets of int: print one element followed by a
 * space.  Returns STATUS_CONTINUE (from set.h -- presumably 0, which
 * set_iterate treats as "keep going"). */
int print (void *ptr)
{
        const int value = *(const int *)ptr;

        printf("%d ",value);
        return STATUS_CONTINUE;
}

/* Print the contents of a set of int on one line, prefixed with "set: ". */
void set_dump (set s)
{
        fputs("set: ",stdout);
        set_iterate(s,print);
        putchar('\n');
}
#endif

/*---------------------------------------------------------------------------*/

#ifdef test

#define BUFLEN 255

/* Interactive test driver: applies commands read from stdin to ten sets of
 * int, s[0]..s[9].  Type "help" for the command list; "quit" or end-of-file
 * exits.  Set indices are parsed with %1d, which reads a single digit, so
 * they always fall within 0-9. */
int main (void)
{
        char buf[BUFLEN];
        int i, j, k, num;
        set s[10];

        for ( i = 0; i < 10; ++i )
        {
                s[i] = set_new(sizeof(int));
                if ( s[i] == NULL )
                {
                        /* Release the sets created so far before giving up */
                        fprintf(stderr,"Cannot create set %d\n",i);
                        while ( --i >= 0 )
                                set_free(s[i]);
                        return EXIT_FAILURE;
                }
        }

        for ( ; ; )
        {
                printf(">");
                fflush(stdout);         /* the prompt has no newline */

                /* Stop at end-of-file or on a read error; otherwise buf would
                   be stale (or, on the first pass, uninitialised) */
                if ( fgets(buf,sizeof buf,stdin) == NULL )
                        break;

                if ( buf[0] == '\n' || buf[0] == '\0' )
                        continue;
                else if ( sscanf(buf,"clear %1d",&i) == 1 )
                        set_clear(s[i]);
                else if ( sscanf(buf,"copy %1d %1d",&i,&j) == 2 )
                        set_copy(s[i],s[j]);
                else if ( sscanf(buf,"equal %1d %1d",&i,&j) == 2 )
                        printf("%s\n",(set_equal(s[i],s[j]) ? "yes" : "no"));
                else if ( sscanf(buf,"empty %1d",&i) == 1 )
                        printf("%s\n",(set_empty(s[i]) ? "yes" : "no"));
                else if ( sscanf(buf,"size %1d",&i) == 1 )
                        printf("%d\n",set_size(s[i]));
                else if ( sscanf(buf,"dump %1d",&i) == 1 )
                        set_dump(s[i]);
                else if ( sscanf(buf,"add %1d %d",&i,&num) == 2 )
                        set_add(s[i],&num);
                else if ( sscanf(buf,"remove %1d %d",&i,&num) == 2 )
                        set_remove(s[i],&num);
                else if ( sscanf(buf,"member %1d %d",&i,&num) == 2 )
                        printf("%s\n", set_member(s[i],&num) ? "yes" : "no");
                else if ( sscanf(buf,"union %1d %1d %1d",&i,&j,&k) == 3 )
                        set_union(s[i],s[j],s[k]);
                else if ( sscanf(buf,"intersection %1d %1d %1d",&i,&j,&k) == 3 )
                        set_intersection(s[i],s[j],s[k]);
                else if ( sscanf(buf,"difference %1d %1d %1d",&i,&j,&k) == 3 )
                        set_difference(s[i],s[j],s[k]);
                else if ( sscanf(buf,"subset %1d %1d",&i,&j) == 2 )
                        printf("%s\n", set_subset(s[i],s[j]) ? "yes" : "no");
                else if ( sscanf(buf,"proper subset %1d %1d",&i,&j) == 2 )
                        printf("%s\n", set_proper_subset(s[i],s[j]) ? "yes" : "no");
                else if ( strncmp(buf,"help",4) == 0 )
                        printf(
                                "clear i\n"
                                "copy i j\n"
                                "equal i j\n"
                                "empty i\n"
                                "size i\n"
                                "dump i\n"
                                "add i n\n"
                                "remove i n\n"
                                "member i n\n"
                                "union i j k\n"
                                "intersection i j k\n"
                                "difference i j k\n"
                                "subset i j\n"
                                "proper subset i j\n"
                              );
                else if ( strncmp(buf,"quit",4) == 0 )
                        break;
                else
                        printf("Mistake\n");
        }

        printf("Deleting s[0-9]\n");
        for ( i = 0; i < 10; ++i )
                set_free(s[i]);

        return 0;
}

#endif
